{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NOTE:  THIS NOTEBOOK WILL TAKE ABOUT 30 MINUTES TO COMPLETE.\n",
    "\n",
    "# PLEASE BE PATIENT."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tune a BERT Model and Create a Text Classifier\n",
    "\n",
    "In the previous section, we performed feature engineering to create BERT embeddings from the `review_body` text using the pre-trained BERT model, and split the dataset into train, validation, and test files. To optimize for TensorFlow training, we saved the files in TFRecord format.\n",
    "\n",
    "Now, let’s fine-tune the BERT model on our Customer Reviews Dataset and add a new classification layer to predict the `star_rating` for a given `review_body`.\n",
    "\n",
    "![BERT Training](img/bert_training.png)\n",
    "\n",
    "As mentioned earlier, BERT is built on the Transformer architecture, which relies on an attention mechanism. This is, not coincidentally, the name of the popular BERT Python library, “Transformers,” maintained by a company called Hugging Face. We will use a variant of BERT called [DistilBERT](https://arxiv.org/pdf/1910.01108.pdf), which requires less memory and compute but maintains very good accuracy on our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name=\"sagemaker\", region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _PRE-REQUISITE: You need to have successfully run the notebooks in the `PREPARE` section before you continue with this notebook._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_train_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    processed_train_data_s3_uri\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_train_data_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_validation_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    processed_validation_data_s3_uri\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_validation_data_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r processed_test_data_s3_uri"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    processed_test_data_s3_uri\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_test_data_s3_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r max_seq_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    max_seq_length\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(max_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r experiment_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    experiment_name\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(experiment_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%store -r trial_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    trial_name\n",
    "except NameError:\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")\n",
    "    print(\"[ERROR] Please run the notebooks in the PREPARE section before you continue.\")\n",
    "    print(\"++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(trial_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the Dataset in S3\n",
    "We are using the train, validation, and test splits created in the previous section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_train_data_s3_uri)\n",
    "\n",
    "!aws s3 ls $processed_train_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(processed_validation_data_s3_uri)\n",
    "\n",
    "!aws s3 ls $processed_validation_data_s3_uri/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "print(processed_test_data_s3_uri)\n",
    "\n",
    "!aws s3 ls $processed_test_data_s3_uri/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify S3 `Distribution Strategy`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.inputs import TrainingInput\n",
    "\n",
    "s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri, distribution=\"ShardedByS3Key\")\n",
    "\n",
    "print(s3_input_train_data.config)\n",
    "print(s3_input_validation_data.config)\n",
    "print(s3_input_test_data.config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Hyper-Parameters for Classification Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(max_seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 3\n",
    "learning_rate = 0.00001\n",
    "epsilon = 0.00000001\n",
    "train_batch_size = 128\n",
    "validation_batch_size = 128\n",
    "test_batch_size = 128\n",
    "train_steps_per_epoch = 100\n",
    "validation_steps = 100\n",
    "test_steps = 100\n",
    "train_instance_count = 1\n",
    "train_instance_type = \"ml.c5.9xlarge\"\n",
    "train_volume_size = 1024\n",
    "use_xla = True\n",
    "use_amp = True\n",
    "freeze_bert_layer = False\n",
    "enable_sagemaker_debugger = True\n",
    "enable_checkpointing = False\n",
    "enable_tensorboard = True\n",
    "input_mode = \"File\"\n",
    "run_validation = True\n",
    "run_test = True\n",
    "run_sample_predictions = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Metrics to Track Model Performance\n",
    "\n",
    "These sample log lines...\n",
    "```\n",
    "45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881\n",
    "50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885\n",
    "```\n",
    "...will produce the following 4 metrics in CloudWatch, named as in the metric definitions below:\n",
    "\n",
    "`train:loss` = 0.425\n",
    "\n",
    "`train:accuracy` = 0.881\n",
    "\n",
    "`validation:loss` = 0.407\n",
    "\n",
    "`validation:accuracy` = 0.885"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"img/cloudwatch_train_metrics.png\" align=\"left\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metrics_definitions = [\n",
    "    {\"Name\": \"train:loss\", \"Regex\": \"loss: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"train:accuracy\", \"Regex\": \"accuracy: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"validation:loss\", \"Regex\": \"val_loss: ([0-9\\\\.]+)\"},\n",
    "    {\"Name\": \"validation:accuracy\", \"Regex\": \"val_accuracy: ([0-9\\\\.]+)\"},\n",
    "]"
   ]
  },
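  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how these regular expressions pick values out of log lines, using Python's `re` module on the two sample lines shown earlier. This is only an illustration: SageMaker applies the patterns to the training job's logs itself, and its matching behavior may differ from `re.search`. Note that under `re.search` the unanchored `loss:` and `accuracy:` patterns also match inside `val_loss:` and `val_accuracy:`, so the second sample line prints values under both the `train:` and `validation:` names."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Sample log lines copied from the markdown cell above.\n",
    "sample_log_lines = [\n",
    "    \"45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881\",\n",
    "    \"50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885\",\n",
    "]\n",
    "\n",
    "# Apply each regex from metrics_definitions to each line with re.search.\n",
    "# Assumption: this only approximates the matching SageMaker performs on the job logs.\n",
    "for line in sample_log_lines:\n",
    "    for definition in metrics_definitions:\n",
    "        match = re.search(definition[\"Regex\"], line)\n",
    "        if match:\n",
    "            print(\"{} = {}\".format(definition[\"Name\"], match.group(1)))"
   ]
  },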
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up SageMaker Debugger\n",
    "Define Debugger rules as described here: https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.debugger import Rule\n",
    "from sagemaker.debugger import rule_configs\n",
    "from sagemaker.debugger import ProfilerRule\n",
    "from sagemaker.debugger import CollectionConfig\n",
    "from sagemaker.debugger import DebuggerHookConfig\n",
    "\n",
    "actions = rule_configs.ActionList(\n",
    "    #    rule_configs.StopTraining(),\n",
    "    #    rule_configs.Email(\"\")\n",
    ")\n",
    "\n",
    "rules = [\n",
    "    ProfilerRule.sagemaker(rule_configs.ProfilerReport()),    \n",
    "#     ProfilerRule.sagemaker(rule_configs.BatchSize()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.CPUBottleneck()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.GPUMemoryIncrease()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.IOBottleneck()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.LoadBalancing()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.LowGPUUtilization()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.OverallSystemUsage()),\n",
    "#     ProfilerRule.sagemaker(rule_configs.StepOutlier()),\n",
    "#     Rule.sagemaker(\n",
    "#         base_config=rule_configs.loss_not_decreasing(),\n",
    "#         rule_parameters={\n",
    "#             \"collection_names\": \"losses,metrics\",\n",
    "#             \"use_losses_collection\": \"true\",\n",
    "#             \"num_steps\": \"10\",\n",
    "#             \"diff_percent\": \"50\",\n",
    "#         },\n",
    "#         collections_to_save=[\n",
    "#             CollectionConfig(\n",
    "#                 name=\"losses\",\n",
    "#                 parameters={\n",
    "#                     \"save_interval\": \"10\",\n",
    "#                 },\n",
    "#             ),\n",
    "#             CollectionConfig(\n",
    "#                 name=\"metrics\",\n",
    "#                 parameters={\n",
    "#                     \"save_interval\": \"10\",\n",
    "#                 },\n",
    "#             ),\n",
    "#         ],\n",
    "#         actions=actions,\n",
    "#     ),\n",
    "#     Rule.sagemaker(\n",
    "#         base_config=rule_configs.overtraining(),\n",
    "#         rule_parameters={\n",
    "#             \"collection_names\": \"losses,metrics\",\n",
    "#             \"patience_train\": \"10\",\n",
    "#             \"patience_validation\": \"10\",\n",
    "#             \"delta\": \"0.5\",\n",
    "#         },\n",
    "#         collections_to_save=[\n",
    "#             CollectionConfig(\n",
    "#                 name=\"losses\",\n",
    "#                 parameters={\n",
    "#                     \"save_interval\": \"10\",\n",
    "#                 },\n",
    "#             ),\n",
    "#             CollectionConfig(\n",
    "#                 name=\"metrics\",\n",
    "#                 parameters={\n",
    "#                     \"save_interval\": \"10\",\n",
    "#                 },\n",
    "#             ),\n",
    "#         ],\n",
    "#         actions=actions,\n",
    "#     )    \n",
    "]\n",
    "\n",
    "hook_config = DebuggerHookConfig(\n",
    "    hook_parameters={\n",
    "        \"save_interval\": \"10\",  # number of steps\n",
    "        \"export_tensorboard\": \"true\",\n",
    "        \"tensorboard_dir\": \"hook_tensorboard/\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Specify a Debugger profiler configuration\n",
    "\n",
    "The following configuration captures system metrics every 500 milliseconds. The system metrics include per-CPU and per-GPU utilization, per-CPU and per-GPU memory utilization, as well as I/O and network.\n",
    "\n",
    "Debugger will also capture detailed profiling information for 10 steps, starting at step 5. This information includes Horovod metrics, data loading, preprocessing, and operators running on CPU and GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.debugger import ProfilerConfig, FrameworkProfile\n",
    "\n",
    "profiler_config = ProfilerConfig(\n",
    "    system_monitor_interval_millis=500,\n",
    "    framework_profile_params=FrameworkProfile(local_path=\"/opt/ml/output/profiler/\", start_step=5, num_steps=10),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify Checkpoint S3 Location\n",
    "The checkpoint location is used for Spot Instance training. If a node is replaced, the new node resumes training from the latest checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import uuid\n",
    "\n",
    "checkpoint_s3_prefix = \"checkpoints/{}\".format(str(uuid.uuid4()))\n",
    "checkpoint_s3_uri = \"s3://{}/{}/\".format(bucket, checkpoint_s3_prefix)\n",
    "\n",
    "print(checkpoint_s3_uri)"
   ]
  },
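  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of what \"resume from the latest checkpoint\" can look like inside a training script. SageMaker syncs a local checkpoint directory in the training container (`/opt/ml/checkpoints` by default) with `checkpoint_s3_uri`, so a restarted node can look there for earlier checkpoints. The helper `find_latest_checkpoint` and the `tf_model_<epoch>.h5` file naming are hypothetical, chosen for illustration; they are not taken from `src/tf_bert_reviews.py`. The demo uses a temporary directory in place of the container's checkpoint directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "\n",
    "def find_latest_checkpoint(checkpoint_dir):\n",
    "    \"\"\"Return the path of the newest checkpoint file in checkpoint_dir, or None.\"\"\"\n",
    "    if not os.path.isdir(checkpoint_dir):\n",
    "        return None\n",
    "    # Hypothetical naming: tf_model_<zero-padded epoch>.h5, so lexical order is epoch order.\n",
    "    candidates = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(\".h5\"))\n",
    "    if not candidates:\n",
    "        return None\n",
    "    return os.path.join(checkpoint_dir, candidates[-1])\n",
    "\n",
    "\n",
    "# Demo: a temporary directory stands in for the local checkpoint directory.\n",
    "with tempfile.TemporaryDirectory() as demo_dir:\n",
    "    for epoch in (1, 2, 3):\n",
    "        open(os.path.join(demo_dir, \"tf_model_{:05d}.h5\".format(epoch)), \"w\").close()\n",
    "\n",
    "    latest = find_latest_checkpoint(demo_dir)\n",
    "    # A training script would load weights from `latest` and resume after that epoch.\n",
    "    print(os.path.basename(latest))"
   ]
  },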
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set Up Our BERT + TensorFlow Script to Run on SageMaker\n",
    "Prepare our TensorFlow model to run on the managed SageMaker service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!pygmentize src/tf_bert_reviews.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "estimator = TensorFlow(\n",
    "    entry_point=\"tf_bert_reviews.py\",\n",
    "    source_dir=\"src\",\n",
    "    role=role,\n",
    "    instance_count=train_instance_count,\n",
    "    instance_type=train_instance_type,\n",
    "    volume_size=train_volume_size,\n",
    "    checkpoint_s3_uri=checkpoint_s3_uri,\n",
    "    py_version=\"py37\",\n",
    "    framework_version=\"2.3.1\",\n",
    "    hyperparameters={\n",
    "        \"epochs\": epochs,\n",
    "        \"learning_rate\": learning_rate,\n",
    "        \"epsilon\": epsilon,\n",
    "        \"train_batch_size\": train_batch_size,\n",
    "        \"validation_batch_size\": validation_batch_size,\n",
    "        \"test_batch_size\": test_batch_size,\n",
    "        \"train_steps_per_epoch\": train_steps_per_epoch,\n",
    "        \"validation_steps\": validation_steps,\n",
    "        \"test_steps\": test_steps,\n",
    "        \"use_xla\": use_xla,\n",
    "        \"use_amp\": use_amp,\n",
    "        \"max_seq_length\": max_seq_length,\n",
    "        \"freeze_bert_layer\": freeze_bert_layer,\n",
    "        \"enable_sagemaker_debugger\": enable_sagemaker_debugger,\n",
    "        \"enable_checkpointing\": enable_checkpointing,\n",
    "        \"enable_tensorboard\": enable_tensorboard,\n",
    "        \"run_validation\": run_validation,\n",
    "        \"run_test\": run_test,\n",
    "        \"run_sample_predictions\": run_sample_predictions,\n",
    "    },\n",
    "    input_mode=input_mode,\n",
    "    metric_definitions=metrics_definitions,\n",
    "    rules=rules,\n",
    "    debugger_hook_config=hook_config,\n",
    "    profiler_config=profiler_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Create the `Experiment Config`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_config = {\"ExperimentName\": experiment_name, \"TrialName\": trial_name, \"TrialComponentDisplayName\": \"train\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the Model on SageMaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.fit(\n",
    "    inputs={\"train\": s3_input_train_data, \"validation\": s3_input_validation_data, \"test\": s3_input_test_data},\n",
    "    experiment_config=experiment_config,\n",
    "    wait=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_name = estimator.latest_training_job.name\n",
    "print(\"Training Job Name:  {}\".format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(\n",
    "            region, training_job_name\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"blank\" href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(\n",
    "            region, training_job_name\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"blank\" href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(\n",
    "            bucket, training_job_name, region\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "display(\n",
    "    HTML(\n",
    "        '<b>Review <a target=\"blank\" href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Checkpoint Data</a> After The Training Job Has Completed</b>'.format(\n",
    "            bucket, checkpoint_s3_prefix, region\n",
    "        )\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "estimator.latest_training_job.wait(logs=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wait Until the ^^ Training Job ^^ Completes Above!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Display Training Job Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.training_job_analytics.dataframe()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [INFO] _Feel free to continue to the next workshop section while this notebook is running._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store training_job_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!aws s3 cp s3://$bucket/$training_job_name/output/model.tar.gz ./model.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p ./model/\n",
    "!tar -xvzf ./model.tar.gz -C ./model/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "!saved_model_cli show --all --dir ./model/tensorflow/saved_model/0/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# !saved_model_cli run --dir ./model/tensorflow/saved_model/0/ --tag_set serve --signature_def serving_default \\\n",
    "#     --input_exprs 'input_ids=np.zeros((1,64));input_mask=np.zeros((1,64))'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# View Confusion Matrix\n",
    "![Confusion Matrix](./model/metrics/confusion_matrix.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyze Debugger Rules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator.latest_training_job.rule_job_summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Show the Experiment Tracking Lineage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.analytics import ExperimentAnalytics\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "pd.set_option(\"max_colwidth\", 500)\n",
    "\n",
    "experiment_analytics = ExperimentAnalytics(\n",
    "    sagemaker_session=sess,\n",
    "    experiment_name=experiment_name,\n",
    "    metric_names=[\"validation:accuracy\"],\n",
    "    sort_by=\"CreationTime\",\n",
    "    sort_order=\"Descending\",\n",
    ")\n",
    "\n",
    "experiment_analytics_df = experiment_analytics.dataframe()\n",
    "experiment_analytics_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.lineage.visualizer import LineageTableVisualizer\n",
    "\n",
    "lineage_table_viz = LineageTableVisualizer(sess)\n",
    "lineage_table_viz_df = lineage_table_viz.show(training_job_name=training_job_name)\n",
    "lineage_table_viz_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Release Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%html\n",
    "\n",
    "<p><b>Shutting down your kernel for this notebook to release resources.</b></p>\n",
    "<button class=\"sm-command-button\" data-commandlinker-command=\"kernelmenu:shutdown\" style=\"display:none;\">Shutdown Kernel</button>\n",
    "        \n",
    "<script>\n",
    "try {\n",
    "    els = document.getElementsByClassName(\"sm-command-button\");\n",
    "    els[0].click();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}    \n",
    "</script>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "\n",
    "try {\n",
    "    Jupyter.notebook.save_checkpoint();\n",
    "    Jupyter.notebook.session.delete();\n",
    "}\n",
    "catch(err) {\n",
    "    // NoOp\n",
    "}"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
